{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "H4OxaKTdoTEM"
   },
   "source": [
    "# Pipeline Example with BERT and TensorFlow Extended (TFX)\n",
    "![](img/tfx-overview.png)\n",
    "\n",
    "Based on the following gist:  https://gist.github.com/hanneshapke/e2dd30ece4c778d335e7d3fafd6ce4ff"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "o18JmOsdRTEw"
   },
   "source": [
    "## Motivation\n",
    "\n",
    "Instead of converting the input to a tranformer model into token ids on the client side, the model exported from this pipeline will allow the conversion on the server side.\n",
    "\n",
    "The pipeline takes advantage of the broad TensorFlow Eco system, including:\n",
    "* Loading the IMDB dataset via **TensorFlow Datasets**\n",
    "* Loading a pre-trained model via **tf.hub**\n",
    "* Manipulating the raw input data with **tf.text**\n",
    "* Building a simple model architecture with **Keras**\n",
    "* Composing the model pipeline with **TensorFlow Extended (TFX)**, e.g. TensorFlow Transform, TensorFlow Data Validation and then consuming the tf.Keras model with the latest Trainer component from TFX\n",
    "\n",
    "\n",
    "### Outline\n",
    "\n",
    "* Load the training data set\n",
    "* Create the TFX Pipeline\n",
    "* Export the trained Model\n",
    "* Test the exported Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Rinax0YJ_otk"
   },
   "source": [
    "# Project Setup\n",
    "\n",
    "## Install Required Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q pip --upgrade\n",
    "!pip install -q wrapt --upgrade --ignore-installed\n",
    "!pip install -q tensorflow==2.1.0\n",
    "!pip install -q transformers==2.8.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 829
    },
    "colab_type": "code",
    "id": "Sjjgiv0bM0hi",
    "outputId": "a33dfe60-c206-4bb2-8c3f-0ba4fd90e898",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install -Uq tfx==0.21.4\n",
    "!pip install -Uq tensorflow-text==2.1.1  # the tf-text version needs to match the tf version\n",
    "!pip install -Uq tensorflow-model-analysis==0.22.1\n",
    "!pip install -Uq tensorflow-data-validation==0.22.0\n",
    "!pip install -Uq tensorflow-transform==0.22.0\n",
    "!pip install -Uq tensorflow_hub==0.8.0\n",
    "!pip install -Uq tensorflow_datasets==3.2.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _Ignore ERRORs ^^ ABOVE ^^_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restart the kernel to pick up pip installed libraries\n",
    "from IPython.core.display import HTML\n",
    "HTML(\"<script>Jupyter.notebook.kernel.restart()</script>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "oThi-x8xLlv-"
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import pprint\n",
    "import re\n",
    "import tempfile\n",
    "from shutil import rmtree\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import tensorflow_data_validation as tfdv\n",
    "import tensorflow_hub as hub\n",
    "import tensorflow_model_analysis as tfma\n",
    "import tensorflow_transform as tft\n",
    "import tensorflow_transform.beam as tft_beam\n",
    "from tensorflow_transform.beam.tft_beam_io import transform_fn_io\n",
    "from tensorflow_transform.saved import saved_transform_io\n",
    "from tensorflow_transform.tf_metadata import (dataset_metadata, dataset_schema,\n",
    "                                              metadata_io, schema_utils)\n",
    "from tfx.components import (Evaluator, ExampleValidator, ImportExampleGen,\n",
    "                            ModelValidator, Pusher, ResolverNode, SchemaGen,\n",
    "                            StatisticsGen, Trainer, Transform)\n",
    "from tfx.components.base import executor_spec\n",
    "from tfx.components.trainer.executor import GenericExecutor\n",
    "from tfx.dsl.experimental import latest_blessed_model_resolver\n",
    "from tfx.proto import evaluator_pb2, example_gen_pb2, pusher_pb2, trainer_pb2\n",
    "from tfx.types import Channel\n",
    "from tfx.types.standard_artifacts import Model, ModelBlessing\n",
    "from tfx.utils.dsl_utils import external_input\n",
    "\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow_model_analysis as tfma\n",
    "import tensorflow_text as text\n",
    "\n",
    "from tfx.orchestration.experimental.interactive.interactive_context import \\\n",
    "    InteractiveContext\n",
    "\n",
    "%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "q4OMItPSLnDj"
   },
   "source": [
    "## Download the IMDB Dataset from TensorFlow Datasets\n",
    "\n",
    "For our demo example, we are using the [IMDB data set](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews) to train a sentiment model based on the pre-trained BERT model. The data set is provided through [TensorFlow Datasets](https://www.tensorflow.org/datasets). Our ML pipeline can read TFRecords, however it expects only TFRecord files in the data folder. That is the reason why we need to delete the additional files provided by TFDS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p ./content/tfds/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KjWjnzPGKjIk"
   },
   "outputs": [],
   "source": [
    "def clean_before_download(base_data_dir):\n",
    "    rmtree(base_data_dir)\n",
    "    \n",
    "def delete_unnecessary_files(base_path):\n",
    "    os.remove(base_path + \"dataset_info.json\")\n",
    "    os.remove(base_path + \"label.labels.txt\")\n",
    "    \n",
    "    counter = 2\n",
    "    for f in glob.glob(base_path + \"imdb_reviews-unsupervised.*\"):\n",
    "        os.remove(f)\n",
    "        counter += 1\n",
    "    print(f\"Deleted {counter} files\")\n",
    "\n",
    "def get_dataset(name='imdb_reviews', version=\"1.0.0\"):\n",
    "\n",
    "    base_data_dir = \"./content/tfds/\"\n",
    "    config=\"plain_text\"\n",
    "    version=\"1.0.0\"\n",
    "\n",
    "    clean_before_download(base_data_dir)\n",
    "    tfds.disable_progress_bar()\n",
    "    builder = tfds.text.IMDBReviews(data_dir=base_data_dir, \n",
    "                                    config=config, \n",
    "                                    version=version)\n",
    "    download_config = tfds.download.DownloadConfig(\n",
    "        download_mode=tfds.GenerateMode.FORCE_REDOWNLOAD)\n",
    "    builder.download_and_prepare(download_config=download_config)\n",
    "\n",
    "    base_tfrecords_filename = os.path.join(base_data_dir, \"imdb_reviews\", config, version, \"\")\n",
    "    train_tfrecords_filename = base_tfrecords_filename + \"imdb_reviews-train*\"\n",
    "    test_tfrecords_filename = base_tfrecords_filename + \"imdb_reviews-test*\"\n",
    "    label_filename = os.path.join(base_tfrecords_filename, \"label.labels.txt\")\n",
    "    labels = [label.rstrip('\\n') for label in open(label_filename)]\n",
    "    delete_unnecessary_files(base_tfrecords_filename)\n",
    "    return (train_tfrecords_filename, test_tfrecords_filename), labels\n",
    "\n",
    "tfrecords_filenames, labels = get_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RDtPNfOwriT8"
   },
   "source": [
    "## Helper function to load the BERT model as Keras layer\n",
    "\n",
    "In our pipeline components, we are reusing the BERT Layer from tf.hub in two places\n",
    "* in the model architecture when we define our Keras model\n",
    "* in our preprocessing function when we extract the BERT settings (casing and vocab file path) to reuse the settings during the tokenization "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Tre_oQu0rlrU"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "%%writefile bert.py\n",
    "\n",
    "import tensorflow_hub as hub\n",
    "\n",
    "BERT_TFHUB_URL = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\"\n",
    "\n",
    "def load_bert_layer(model_url=BERT_TFHUB_URL):\n",
    "    # Load the pre-trained BERT model as layer in Keras\n",
    "    bert_layer = hub.KerasLayer(\n",
    "        handle=model_url,\n",
    "        trainable=True)\n",
    "    return bert_layer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "k-5QGnm_lFJD"
   },
   "source": [
    "# TFX Pipeline\n",
    "\n",
    "The TensorFlow Extended Pipeline is more or less following the example setup shown here. We'll only note deviations from the original setup."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "arPCEBYEFqEr"
   },
   "source": [
    "## Initializing the Interactive TFX Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sBO0T3D5kkOt"
   },
   "outputs": [],
   "source": [
    "context = InteractiveContext()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qo2Q-c_ynL2x"
   },
   "source": [
    "## Loading the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "W8GqUHwAKm6j"
   },
   "outputs": [],
   "source": [
    "output = example_gen_pb2.Output(\n",
    "             split_config=example_gen_pb2.SplitConfig(splits=[\n",
    "                 example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=45),\n",
    "                 example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=5)\n",
    "             ]))\n",
    "# Load the data from our prepared TFDS folder\n",
    "examples = external_input(\"./content/tfds/imdb_reviews/plain_text/1.0.0\")\n",
    "example_gen = ImportExampleGen(input=examples, output_config=output)\n",
    "\n",
    "context.run(example_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iu2ejZTWK-E5"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "\n",
    "for artifact in example_gen.outputs['examples'].get():\n",
    "    print(artifact.uri)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nE9VL-pmF6L_"
   },
   "source": [
    "## TensorFlow Data Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EglytaKVLQzr"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "\n",
    "statistics_gen = StatisticsGen(\n",
    "    examples=example_gen.outputs['examples'])\n",
    "context.run(statistics_gen)\n",
    "\n",
    "context.show(statistics_gen.outputs['statistics'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IBYoEPhBeQUi"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "\n",
    "schema_gen = SchemaGen(\n",
    "    statistics=statistics_gen.outputs['statistics'],\n",
    "    infer_feature_shape=True)\n",
    "context.run(schema_gen)\n",
    "\n",
    "context.show(schema_gen.outputs['schema'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bl2gkqytjr0w"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "\n",
    "# check the data schema for the type of input tensors\n",
    "tfdv.load_schema_text(schema_gen.outputs['schema'].get()[0].uri + \"/schema.pbtxt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IWRswNYye6So"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "\n",
    "example_validator = ExampleValidator(\n",
    "    statistics=statistics_gen.outputs['statistics'],\n",
    "    schema=schema_gen.outputs['schema'])\n",
    "context.run(example_validator)\n",
    "\n",
    "context.show(example_validator.outputs['anomalies'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_zqjzTx2s5HS"
   },
   "source": [
    "## TensorFlow Transform\n",
    "\n",
    "This is where we perform the BERT processing. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K91irJq7q6vC"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "%%writefile transform.py\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_text as text\n",
    "\n",
    "from bert import load_bert_layer\n",
    "\n",
    "MAX_SEQ_LEN = 64  # max number is 512\n",
    "do_lower_case = load_bert_layer().resolved_object.do_lower_case.numpy()\n",
    "\n",
    "def preprocessing_fn(inputs):\n",
    "    \"\"\"Preprocess input column of text into transformed columns of.\n",
    "        * input token ids\n",
    "        * input mask\n",
    "        * input type ids\n",
    "    \"\"\"\n",
    "\n",
    "    CLS_ID = tf.constant(101, dtype=tf.int64)\n",
    "    SEP_ID = tf.constant(102, dtype=tf.int64)\n",
    "    PAD_ID = tf.constant(0, dtype=tf.int64)\n",
    "\n",
    "    vocab_file_path = load_bert_layer().resolved_object.vocab_file.asset_path\n",
    "    \n",
    "    bert_tokenizer = text.BertTokenizer(vocab_lookup_table=vocab_file_path, \n",
    "                                        token_out_type=tf.int64, \n",
    "                                        lower_case=do_lower_case) \n",
    "    \n",
    "    def tokenize_text(text, sequence_length=MAX_SEQ_LEN):\n",
    "        \"\"\"\n",
    "        Perform the BERT preprocessing from text -> input token ids\n",
    "        \"\"\"\n",
    "\n",
    "        # convert text into token ids\n",
    "        tokens = bert_tokenizer.tokenize(text)\n",
    "        \n",
    "        # flatten the output ragged tensors \n",
    "        tokens = tokens.merge_dims(1, 2)[:, :sequence_length]\n",
    "        \n",
    "        # Add start and end token ids to the id sequence\n",
    "        start_tokens = tf.fill([tf.shape(text)[0], 1], CLS_ID)\n",
    "        end_tokens = tf.fill([tf.shape(text)[0], 1], SEP_ID)\n",
    "        tokens = tokens[:, :sequence_length - 2]\n",
    "        tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)\n",
    "\n",
    "        # truncate sequences greater than MAX_SEQ_LEN\n",
    "        tokens = tokens[:, :sequence_length]\n",
    "\n",
    "        # pad shorter sequences with the pad token id\n",
    "        tokens = tokens.to_tensor(default_value=PAD_ID)\n",
    "        pad = sequence_length - tf.shape(tokens)[1]\n",
    "        tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values=PAD_ID)\n",
    "\n",
    "        # and finally reshape the word token ids to fit the output \n",
    "        # data structure of TFT  \n",
    "        return tf.reshape(tokens, [-1, sequence_length])\n",
    "\n",
    "    def preprocess_bert_input(text):\n",
    "        \"\"\"\n",
    "        Convert input text into the input_word_ids, input_mask, input_type_ids\n",
    "        \"\"\"\n",
    "        input_word_ids = tokenize_text(text)\n",
    "        input_mask = tf.cast(input_word_ids > 0, tf.int64)\n",
    "        input_mask = tf.reshape(input_mask, [-1, MAX_SEQ_LEN])\n",
    "        \n",
    "        zeros_dims = tf.stack(tf.shape(input_mask))\n",
    "        input_type_ids = tf.fill(zeros_dims, 0)\n",
    "        input_type_ids = tf.cast(input_type_ids, tf.int64)\n",
    "\n",
    "        return (\n",
    "            input_word_ids, \n",
    "            input_mask,\n",
    "            input_type_ids\n",
    "        )\n",
    "\n",
    "    input_word_ids, input_mask, input_type_ids = \\\n",
    "        preprocess_bert_input(tf.squeeze(inputs['text'], axis=1))\n",
    "\n",
    "    return {\n",
    "        'input_word_ids': input_word_ids,\n",
    "        'input_mask': input_mask,\n",
    "        'input_type_ids': input_type_ids,\n",
    "        'label': inputs['label']\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Sz5cevHYrR6M"
   },
   "outputs": [],
   "source": [
    "transform = Transform(\n",
    "    examples=example_gen.outputs['examples'],\n",
    "    schema=schema_gen.outputs['schema'],\n",
    "    module_file=os.path.abspath(\"transform.py\"))\n",
    "context.run(transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yCcNJxWSKPPv"
   },
   "source": [
    "#### Check the Output Data Struture of the TF Transform Operation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PcUEGmuhtmGi"
   },
   "outputs": [],
   "source": [
    "from tfx_bsl.coders.example_coder import ExampleToNumpyDict\n",
    "\n",
    "pp = pprint.PrettyPrinter()\n",
    "\n",
    "# Get the URI of the output artifact representing the transformed examples, which is a directory\n",
    "train_uri = transform.outputs['transformed_examples'].get()[0].uri\n",
    "\n",
    "print(train_uri)\n",
    "\n",
    "# Get the list of files in this directory (all compressed TFRecord files)\n",
    "tfrecord_folders = [os.path.join(train_uri, name) for name in os.listdir(train_uri)]\n",
    "tfrecord_filenames = []\n",
    "for tfrecord_folder in tfrecord_folders:\n",
    "    for name in os.listdir(tfrecord_folder):\n",
    "        tfrecord_filenames.append(os.path.join(tfrecord_folder, name))\n",
    "\n",
    "\n",
    "# Create a TFRecordDataset to read these files\n",
    "dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type=\"GZIP\")\n",
    "\n",
    "for tfrecord in dataset.take(1):\n",
    "    serialized_example = tfrecord.numpy()\n",
    "    example = ExampleToNumpyDict(serialized_example)\n",
    "    pp.pprint(example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MdsXSG52kVoL"
   },
   "source": [
    "## Training of the Keras Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ywjksr-vtxrX"
   },
   "outputs": [],
   "source": [
    "%%skip_for_export\n",
    "%%writefile trainer.py\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "import tensorflow_model_analysis as tfma\n",
    "import tensorflow_transform as tft\n",
    "from tensorflow_transform.tf_metadata import schema_utils\n",
    "\n",
    "from typing import Text\n",
    "\n",
    "import absl\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import tensorflow_transform as tft\n",
    "from tfx.components.trainer.executor import TrainerFnArgs\n",
    "\n",
    "\n",
    "_LABEL_KEY = 'label'\n",
    "BERT_TFHUB_URL = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\"\n",
    "\n",
    "\n",
    "def _gzip_reader_fn(filenames):\n",
    "    \"\"\"Small utility returning a record reader that can read gzip'ed files.\"\"\"\n",
    "    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')\n",
    "\n",
    "def load_bert_layer(model_url=BERT_TFHUB_URL):\n",
    "    # Load the pre-trained BERT model as layer in Keras\n",
    "    bert_layer = hub.KerasLayer(\n",
    "        handle=model_url,\n",
    "        trainable=False)  # model can be fine-tuned \n",
    "    return bert_layer\n",
    "\n",
    "def get_model(tf_transform_output, max_seq_length=64, num_labels=2):\n",
    "\n",
    "    # dynamically create inputs for all outputs of our transform graph\n",
    "    feature_spec = tf_transform_output.transformed_feature_spec()  \n",
    "    feature_spec.pop(_LABEL_KEY)\n",
    "\n",
    "    inputs = {\n",
    "        key: tf.keras.layers.Input(shape=(max_seq_length), name=key, dtype=tf.int64)\n",
    "            for key in feature_spec.keys()\n",
    "    }\n",
    "\n",
    "    input_word_ids = tf.cast(inputs[\"input_word_ids\"], dtype=tf.int32)\n",
    "    input_mask = tf.cast(inputs[\"input_mask\"], dtype=tf.int32)\n",
    "    input_type_ids = tf.cast(inputs[\"input_type_ids\"], dtype=tf.int32)\n",
    "\n",
    "    bert_layer = load_bert_layer()\n",
    "    pooled_output, _ = bert_layer(\n",
    "        [input_word_ids, \n",
    "         input_mask, \n",
    "         input_type_ids\n",
    "        ]\n",
    "    )\n",
    "    \n",
    "    # Add additional layers depending on your problem\n",
    "    x = tf.keras.layers.Dense(256, activation='relu')(pooled_output)\n",
    "    dense = tf.keras.layers.Dense(64, activation='relu')(x)\n",
    "    pred = tf.keras.layers.Dense(1, activation='sigmoid')(dense)\n",
    "\n",
    "    keras_model = tf.keras.Model(\n",
    "        inputs=[\n",
    "                inputs['input_word_ids'], \n",
    "                inputs['input_mask'], \n",
    "                inputs['input_type_ids']], \n",
    "        outputs=pred)\n",
    "    keras_model.compile(loss='binary_crossentropy', \n",
    "                        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), \n",
    "                        metrics=['accuracy']\n",
    "                        )\n",
    "    return keras_model\n",
    "\n",
    "\n",
    "def _get_serve_tf_examples_fn(model, tf_transform_output):\n",
    "    \"\"\"Returns a function that parses a serialized tf.Example and applies TFT.\"\"\"\n",
    "\n",
    "    model.tft_layer = tf_transform_output.transform_features_layer()\n",
    "\n",
    "    @tf.function\n",
    "    def serve_tf_examples_fn(serialized_tf_examples):\n",
    "        \"\"\"Returns the output to be used in the serving signature.\"\"\"\n",
    "        feature_spec = tf_transform_output.raw_feature_spec()\n",
    "        feature_spec.pop(_LABEL_KEY)\n",
    "        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)\n",
    "\n",
    "        transformed_features = model.tft_layer(parsed_features)\n",
    "\n",
    "        outputs = model(transformed_features)\n",
    "        return {'outputs': outputs}\n",
    "\n",
    "    return serve_tf_examples_fn\n",
    "\n",
    "def _input_fn(file_pattern: Text,\n",
    "              tf_transform_output: tft.TFTransformOutput,\n",
    "              batch_size: int = 32) -> tf.data.Dataset:\n",
    "    \"\"\"Generates features and label for tuning/training.\n",
    "\n",
    "    Args:\n",
    "      file_pattern: input tfrecord file pattern.\n",
    "      tf_transform_output: A TFTransformOutput.\n",
    "      batch_size: representing the number of consecutive elements of returned\n",
    "        dataset to combine in a single batch\n",
    "\n",
    "    Returns:\n",
    "      A dataset that contains (features, indices) tuple where features is a\n",
    "        dictionary of Tensors, and indices is a single Tensor of label indices.\n",
    "    \"\"\"\n",
    "    transformed_feature_spec = (\n",
    "        tf_transform_output.transformed_feature_spec().copy())\n",
    "\n",
    "    dataset = tf.data.experimental.make_batched_features_dataset(\n",
    "        file_pattern=file_pattern,\n",
    "        batch_size=batch_size,\n",
    "        features=transformed_feature_spec,\n",
    "        reader=_gzip_reader_fn,\n",
    "        label_key=_LABEL_KEY)\n",
    "\n",
    "    return dataset\n",
    "\n",
    "# TFX Trainer will call this function.\n",
    "def run_fn(fn_args: TrainerFnArgs):\n",
    "    \"\"\"Train the model based on given args.\n",
    "\n",
    "    Args:\n",
    "      fn_args: Holds args used to train the model as name/value pairs.\n",
    "    \"\"\"\n",
    "    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)\n",
    "\n",
    "    train_dataset = _input_fn(fn_args.train_files, tf_transform_output, 32)\n",
    "    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, 32)\n",
    "\n",
    "    mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "    with mirrored_strategy.scope():\n",
    "        model = get_model(tf_transform_output=tf_transform_output)\n",
    "\n",
    "    model.fit(\n",
    "        train_dataset,\n",
    "        steps_per_epoch=fn_args.train_steps,\n",
    "        validation_data=eval_dataset,\n",
    "        validation_steps=fn_args.eval_steps)\n",
    "\n",
    "    signatures = {\n",
    "        'serving_default':\n",
    "            _get_serve_tf_examples_fn(model,\n",
    "                                      tf_transform_output).get_concrete_function(\n",
    "                                          tf.TensorSpec(\n",
    "                                              shape=[None],\n",
    "                                              dtype=tf.string,\n",
    "                                              name='examples')),\n",
    "    }\n",
    "    model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "b4n7fkCbnvHW"
   },
   "outputs": [],
   "source": [
    "# NOTE: Adjust the number of training and evaluation steps\n",
    "TRAINING_STEPS = 50\n",
    "EVALUATION_STEPS = 50\n",
    "\n",
    "trainer = Trainer(\n",
    "    module_file=os.path.abspath(\"trainer.py\"),\n",
    "    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),\n",
    "    examples=transform.outputs['transformed_examples'],\n",
    "    transform_graph=transform.outputs['transform_graph'],\n",
    "    schema=schema_gen.outputs['schema'],\n",
    "    train_args=trainer_pb2.TrainArgs(num_steps=TRAINING_STEPS),\n",
    "    eval_args=trainer_pb2.EvalArgs(num_steps=EVALUATION_STEPS))\n",
    "context.run(trainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LD2kK5XQenDL"
   },
   "outputs": [],
   "source": [
    "model_resolver = ResolverNode(\n",
    "    instance_name='latest_blessed_model_resolver',\n",
    "    resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,\n",
    "    model=Channel(type=Model),\n",
    "    model_blessing=Channel(type=ModelBlessing))\n",
    "\n",
    "context.run(model_resolver)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SgH50dYW5C2T"
   },
   "source": [
    "## TensorFlow Model Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NVOPbS9Te6MW"
   },
   "outputs": [],
   "source": [
    "eval_config = tfma.EvalConfig(\n",
    "    model_specs=[\n",
    "        tfma.ModelSpec(label_key='label')\n",
    "    ],\n",
    "    metrics_specs=[\n",
    "        tfma.MetricsSpec(\n",
    "            metrics=[\n",
    "                tfma.MetricConfig(class_name='ExampleCount')\n",
    "            ],\n",
    "            thresholds = {\n",
    "                'binary_accuracy': tfma.MetricThreshold(\n",
    "                    value_threshold=tfma.GenericValueThreshold(\n",
    "                        lower_bound={'value': 0.5}),\n",
    "                    change_threshold=tfma.GenericChangeThreshold(\n",
    "                       direction=tfma.MetricDirection.HIGHER_IS_BETTER,\n",
    "                       absolute={'value': -1e-10}))\n",
    "            }\n",
    "        )\n",
    "    ],\n",
    "    slicing_specs=[\n",
    "        # An empty slice spec means the overall slice, i.e. the whole dataset.\n",
    "        tfma.SlicingSpec(),\n",
    "    ])\n",
    "\n",
    "evaluator = Evaluator(\n",
    "    examples=example_gen.outputs['examples'],\n",
    "    model=trainer.outputs['model'],\n",
    "    baseline_model=model_resolver.outputs['model'],\n",
    "    eval_config=eval_config\n",
    ")\n",
    "\n",
    "context.run(evaluator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CD3Q8gnznAnT"
   },
   "outputs": [],
   "source": [
    "# Check the blessing\n",
    "!ls {evaluator.outputs['blessing'].get()[0].uri}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5f4Z0vJWOIyk"
   },
   "source": [
    "## Model Export for Serving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CxxXrsdebY63"
   },
   "outputs": [],
   "source": [
    "!mkdir ./content/serving_model_dir\n",
    "\n",
    "serving_model_dir = \"./content/serving_model_dir\"\n",
    "\n",
    "pusher = Pusher(\n",
    "    model=trainer.outputs['model'],\n",
    "    model_blessing=evaluator.outputs['blessing'],\n",
    "    push_destination=pusher_pb2.PushDestination(\n",
    "        filesystem=pusher_pb2.PushDestination.Filesystem(\n",
    "            base_directory=serving_model_dir)))\n",
    "\n",
    "context.run(pusher)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WWni3fVVafDa"
   },
   "source": [
    "## Test your Exported Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bTi19Ojrbumq"
   },
   "outputs": [],
   "source": [
    "def _bytes_feature(value):\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "\n",
    "push_uri = pusher.outputs.model_push.get()[0].uri\n",
    "latest_version = max(os.listdir(push_uri))\n",
    "latest_version_path = os.path.join(push_uri, latest_version)\n",
    "loaded_model = tf.saved_model.load(latest_version_path)\n",
    "\n",
    "example_str = b\"This is the finest show ever produced for TV. Each episode is a triumph. The casting, the writing, the timing are all second to none. This cast performs miracles.\"\n",
    "example = tf.train.Example(features=tf.train.Features(feature={\n",
    "    'text': _bytes_feature(example_str)}))\n",
    "\n",
    "serialized_example = example.SerializeToString()\n",
    "f = loaded_model.signatures[\"serving_default\"]\n",
    "print(f(tf.constant([serialized_example])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Y-Yr3qVov33c"
   },
   "outputs": [],
   "source": [
    "def _bytes_feature(value):\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "\n",
    "push_uri = pusher.outputs.model_push.get()[0].uri\n",
    "latest_version = max(os.listdir(push_uri))\n",
    "latest_version_path = os.path.join(push_uri, latest_version)\n",
    "loaded_model = tf.saved_model.load(latest_version_path)\n",
    "\n",
    "example_str = b\"I loved it!\"\n",
    "example = tf.train.Example(features=tf.train.Features(feature={\n",
    "    'text': _bytes_feature(example_str)}))\n",
    "\n",
    "serialized_example = example.SerializeToString()\n",
    "f = loaded_model.signatures[\"serving_default\"]\n",
    "print(f(tf.constant([serialized_example])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mRpBI4Ojw_hT"
   },
   "outputs": [],
   "source": [
    "def _bytes_feature(value):\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "\n",
    "push_uri = pusher.outputs.model_push.get()[0].uri\n",
    "latest_version = max(os.listdir(push_uri))\n",
    "latest_version_path = os.path.join(push_uri, latest_version)\n",
    "loaded_model = tf.saved_model.load(latest_version_path)\n",
    "\n",
    "example_str = b\"It's OK.\"\n",
    "example = tf.train.Example(features=tf.train.Features(feature={\n",
    "    'text': _bytes_feature(example_str)}))\n",
    "\n",
    "serialized_example = example.SerializeToString()\n",
    "f = loaded_model.signatures[\"serving_default\"]\n",
    "print(f(tf.constant([serialized_example])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Fxy84A2sxAba"
   },
   "outputs": [],
   "source": [
    "def _bytes_feature(value):\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "\n",
    "push_uri = pusher.outputs.model_push.get()[0].uri\n",
    "latest_version = max(os.listdir(push_uri))\n",
    "latest_version_path = os.path.join(push_uri, latest_version)\n",
    "loaded_model = tf.saved_model.load(latest_version_path)\n",
    "\n",
    "example_str = b\"The worst product ever.\"\n",
    "example = tf.train.Example(features=tf.train.Features(feature={\n",
    "    'text': _bytes_feature(example_str)}))\n",
    "\n",
    "serialized_example = example.SerializeToString()\n",
    "f = loaded_model.signatures[\"serving_default\"]\n",
    "print(f(tf.constant([serialized_example])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Model has been exported to {}'.format(pusher.outputs.model_push.get()[0].uri))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for path in os.walk('{}/'.format(pusher.outputs.model_push.get()[0].uri)):\n",
    "    print(path[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "Jupyter.notebook.save_checkpoint();\n",
    "Jupyter.notebook.session.delete();"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "name": "TFX_Pipeline_for_Bert_Preprocessing.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
